{"cells":[{"cell_type":"markdown","metadata":{},"source":["### Checkout my [Twitter(@rohanpaul_ai)](https://twitter.com/rohanpaul_ai) for daily LLM bits"]},{"cell_type":"markdown","metadata":{},"source":["##  Fine Tuning Mistral-7B with PEFT and QLoRA\n","\n","# [Link to my Youtube Video Explaining this whole Notebook](https://www.youtube.com/watch?v=6DGYj1EEWOw&list=PLxqBkZuBynVTzqUQCQFgetR97y1X_1uCI&index=13&ab_channel=Rohan-Paul-AI)\n","\n","[![Imgur](https://imgur.com/MJ22PMV.png)](https://www.youtube.com/watch?v=6DGYj1EEWOw&list=PLxqBkZuBynVTzqUQCQFgetR97y1X_1uCI&index=13&ab_channel=Rohan-Paul-AI)"]},{"cell_type":"markdown","metadata":{},"source":["![](assets/2023-11-25-00-08-01.png)"]},{"cell_type":"code","execution_count":1,"metadata":{"execution":{"iopub.execute_input":"2023-11-22T19:57:36.081861Z","iopub.status.busy":"2023-11-22T19:57:36.081116Z","iopub.status.idle":"2023-11-22T19:57:36.090426Z","shell.execute_reply":"2023-11-22T19:57:36.089534Z","shell.execute_reply.started":"2023-11-22T19:57:36.081827Z"},"trusted":true},"outputs":[],"source":["!pip install --upgrade peft accelerate bitsandbytes datasets trl"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2023-11-22T19:57:36.101514Z","iopub.status.busy":"2023-11-22T19:57:36.101273Z","iopub.status.idle":"2023-11-22T19:57:52.418945Z","shell.execute_reply":"2023-11-22T19:57:52.418028Z","shell.execute_reply.started":"2023-11-22T19:57:36.101492Z"},"trusted":true},"outputs":[],"source":["import os\n","from dataclasses import dataclass, field\n","from typing import Optional\n","from datasets.arrow_dataset import Dataset\n","import torch\n","from datasets import load_dataset\n","from peft import LoraConfig\n","from peft import AutoPeftModelForCausalLM\n","from transformers import (\n","    AutoModelForCausalLM,\n","    AutoTokenizer,\n","    BitsAndBytesConfig,\n","    HfArgumentParser,\n","    AutoTokenizer,\n","    TrainingArguments,\n",")\n","\n","from trl import SFTTrainer\n","\n","torch.manual_seed(42)"]},{"cell_type":"code","execution_count":3,"metadata":{"execution":{"iopub.execute_input":"2023-11-22T19:57:52.420978Z","iopub.status.busy":"2023-11-22T19:57:52.420355Z","iopub.status.idle":"2023-11-22T19:57:52.432297Z","shell.execute_reply":"2023-11-22T19:57:52.431294Z","shell.execute_reply.started":"2023-11-22T19:57:52.420944Z"},"trusted":true},"outputs":[],"source":["@dataclass\n","class ScriptArguments:\n","    \"\"\"\n","    These arguments vary depending on how many GPUs you have, what their capacity and features are, and what size model you want to train.\n","    \"\"\"\n","    local_rank: Optional[int] = -1\n","    per_device_train_batch_size: Optional[int] = 4\n","    per_device_eval_batch_size: Optional[int] = 4\n","    gradient_accumulation_steps: Optional[int] = 4\n","    learning_rate: Optional[float] = 2e-5\n","    max_grad_norm: Optional[float] = 0.3\n","    weight_decay: Optional[int] = 0.01\n","    lora_alpha: Optional[int] = 16\n","    lora_dropout: Optional[float] = 0.1\n","    lora_r: Optional[int] = 32\n","    max_seq_length: Optional[int] = 512\n","    # model_name: Optional[str] = \"bn22/Mistral-7B-Instruct-v0.1-sharded\"\n","    model_name: Optional[str] = \"mistralai/Mistral-7B-Instruct-v0.1\"\n","    dataset_name: Optional[str] = \"iamtarun/python_code_instructions_18k_alpaca\"\n","    use_4bit: Optional[bool] = True\n","    use_nested_quant: Optional[bool] = False\n","    bnb_4bit_compute_dtype: Optional[str] = \"float16\"\n","    bnb_4bit_quant_type: Optional[str] = \"nf4\"\n","    num_train_epochs: Optional[int] = 100\n","    fp16: Optional[bool] = False\n","    bf16: Optional[bool] = True\n","    packing: Optional[bool] = False\n","    gradient_checkpointing: Optional[bool] = True\n","    optim: Optional[str] = \"paged_adamw_32bit\"\n","    lr_scheduler_type: str = \"constant\"\n","    max_steps: int = 1000000\n","    warmup_ratio: float = 0.03\n","    group_by_length: bool = True\n","    save_steps: int = 50\n","    logging_steps: int = 50\n","    merge_and_push: Optional[bool] = False\n","    output_dir: str = \"./results_packing\""]},{"cell_type":"code","execution_count":4,"metadata":{"execution":{"iopub.execute_input":"2023-11-22T19:57:52.435734Z","iopub.status.busy":"2023-11-22T19:57:52.435359Z","iopub.status.idle":"2023-11-22T19:57:52.453642Z","shell.execute_reply":"2023-11-22T19:57:52.452833Z","shell.execute_reply.started":"2023-11-22T19:57:52.435684Z"},"trusted":true},"outputs":[],"source":["# parser = HfArgumentParser(ScriptArguments)\n","# script_args = parser.parse_args_into_dataclasses()[0]"]},{"cell_type":"code","execution_count":5,"metadata":{"execution":{"iopub.execute_input":"2023-11-22T19:57:52.455365Z","iopub.status.busy":"2023-11-22T19:57:52.454866Z","iopub.status.idle":"2023-11-22T19:57:52.464472Z","shell.execute_reply":"2023-11-22T19:57:52.463591Z","shell.execute_reply.started":"2023-11-22T19:57:52.455334Z"},"trusted":true},"outputs":[],"source":["script_args = ScriptArguments(\n","    local_rank=-1,\n","    per_device_train_batch_size=1,  # custom value\n","    per_device_eval_batch_size=1,\n","    gradient_accumulation_steps=4,\n","    learning_rate=3e-5,  # custom value\n","    max_grad_norm=0.3,\n","    weight_decay=0.01,\n","    lora_alpha=16,\n","    lora_dropout=0.1,\n","    lora_r=32,\n","    max_seq_length=512,\n","    # model_name=\"bn22/Mistral-7B-Instruct-v0.1-sharded\",\n","    model_name=\"mistralai/Mistral-7B-Instruct-v0.1\",\n","    dataset_name=\"iamtarun/python_code_instructions_18k_alpaca\",\n","    use_4bit=True,\n","    use_nested_quant=False,\n","    bnb_4bit_compute_dtype=\"float16\",\n","    bnb_4bit_quant_type=\"nf4\",\n","    num_train_epochs=100,\n","    fp16=True,\n","    bf16=False,\n","    packing=False,\n","    gradient_checkpointing=True,\n","    optim=\"paged_adamw_32bit\",\n","    lr_scheduler_type=\"constant\",\n","    max_steps=1000000,\n","    warmup_ratio=0.03,\n","    group_by_length=True,\n","    save_steps=50,\n","    logging_steps=50,\n","    merge_and_push=False,\n","    output_dir=\"./results_packing\"\n",")\n"]},{"cell_type":"markdown","metadata":{},"source":["## Data Preprocessing Utils"]},{"cell_type":"code","execution_count":6,"metadata":{"execution":{"iopub.execute_input":"2023-11-22T19:57:52.481305Z","iopub.status.busy":"2023-11-22T19:57:52.481017Z","iopub.status.idle":"2023-11-22T19:57:52.499029Z","shell.execute_reply":"2023-11-22T19:57:52.498298Z","shell.execute_reply.started":"2023-11-22T19:57:52.481282Z"},"trusted":true},"outputs":[],"source":["def gen_batches_train():\n","    ds = load_dataset(script_args.dataset_name, streaming=True, split=\"train\")\n","    total_samples = 10000\n","    val_pct = 0.1\n","    train_limit = int(total_samples * (1 - val_pct))\n","    counter = 0\n","\n","    for sample in iter(ds):\n","        if counter >= train_limit:\n","            break\n","\n","        original_prompt = sample['prompt'].replace(\"### Input:\\n\", '').replace('# Python code\\n', '')\n","\n","        instruction_start = original_prompt.find(\"### Instruction:\") + len(\"### Instruction:\")\n","\n","        instruction_end = original_prompt.find(\"### Output:\")\n","\n","        instruction = original_prompt[instruction_start:instruction_end].strip()\n","\n","        content_start = original_prompt.find(\"### Output:\") + len(\"### Output:\")\n","\n","        content = original_prompt[content_start:].strip()\n","\n","        new_text_format = f'<s>[INST] {instruction} [/INST] ```python\\n{content}```</s>'\n","\n","        tokenized_output = tokenizer(new_text_format)\n","\n","        yield {'text': new_text_format}\n","\n","        counter += 1\n","\n","def gen_batches_val():\n","    ds = load_dataset(script_args.dataset_name, streaming=True, split=\"train\")\n","    total_samples = 10000\n","    val_pct = 0.1\n","    train_limit = int(total_samples * (1 - val_pct))\n","    counter = 0\n","\n","    for sample in iter(ds):\n","        if counter < train_limit:\n","            counter += 1\n","            continue\n","\n","        if counter >= total_samples:\n","            break\n","\n","        original_prompt = sample['prompt'].replace(\"### Input:\\n\", '').replace('# Python code\\n', '')\n","        instruction_start = original_prompt.find(\"### Instruction:\") + len(\"### Instruction:\")\n","        instruction_end = original_prompt.find(\"### Output:\")\n","\n","        instruction = original_prompt[instruction_start:instruction_end].strip()\n","        content_start = original_prompt.find(\"### Output:\") + len(\"### Output:\")\n","        content = original_prompt[content_start:].strip()\n","        new_text_format = f'<s>[INST] {instruction} [/INST] ```python\\n{content}```</s>'\n","\n","        tokenized_output = tokenizer(new_text_format)\n","        yield {'text': new_text_format}\n","\n","        counter += 1"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[],"source":["\n","def create_and_prepare_model(args):\n","    compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)\n","\n","    bnb_config = BitsAndBytesConfig(\n","        load_in_4bit=args.use_4bit,\n","        bnb_4bit_quant_type=args.bnb_4bit_quant_type,\n","        bnb_4bit_compute_dtype=compute_dtype,\n","        bnb_4bit_use_double_quant=args.use_nested_quant,\n","    )\n","\n","    if compute_dtype == torch.float16 and args.use_4bit:\n","        major, _ = torch.cuda.get_device_capability()\n","        if major >= 8:\n","            print(\"=\" * 80)\n","            print(\"Your GPU supports bfloat16, you can accelerate training with the argument --bf16\")\n","            print(\"=\" * 80)\n","\n","    # Load the entire model on the GPU 0\n","    # switch to `device_map = \"auto\"` for multi-GPU\n","    device_map = {\"\": 0}\n","\n","    model = AutoModelForCausalLM.from_pretrained(\n","        args.model_name,\n","        quantization_config=bnb_config,\n","        device_map=device_map,\n","        # use_auth_token=True,\n","        # revision=\"refs/pr/35\"\n","    )\n","\n","    #### LLAMA STUFF\n","    # check: https://github.com/huggingface/transformers/pull/24906\n","    model.config.pretraining_tp = 1\n","    # model.config.\n","    #### LLAMA STUFF\n","    model.config.window = 256\n","\n","    peft_config = LoraConfig(\n","        lora_alpha=script_args.lora_alpha,\n","        lora_dropout=script_args.lora_dropout,\n","        # target_modules=[\"query_key_value\"],\n","        r=script_args.lora_r,\n","        bias=\"none\",\n","        task_type=\"CAUSAL_LM\",\n","        target_modules=[\n","        \"q_proj\",\n","        \"k_proj\",\n","        \"v_proj\",\n","        \"o_proj\",\n","    ],\n","    )\n","\n","    tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True)\n","    tokenizer.pad_token = tokenizer.eos_token\n","\n","    return model, peft_config, tokenizer"]},{"cell_type":"code","execution_count":8,"metadata":{"execution":{"iopub.execute_input":"2023-11-22T19:57:52.500253Z","iopub.status.busy":"2023-11-22T19:57:52.500015Z","iopub.status.idle":"2023-11-22T19:57:52.516188Z","shell.execute_reply":"2023-11-22T19:57:52.515464Z","shell.execute_reply.started":"2023-11-22T19:57:52.500232Z"},"trusted":true},"outputs":[],"source":["training_arguments = TrainingArguments(\n","    output_dir=script_args.output_dir,\n","    per_device_train_batch_size=script_args.per_device_train_batch_size,\n","    gradient_accumulation_steps=script_args.gradient_accumulation_steps,\n","    optim=script_args.optim,\n","    save_steps=script_args.save_steps,\n","    logging_steps=script_args.logging_steps,\n","    learning_rate=script_args.learning_rate,\n","    fp16=script_args.fp16,\n","    bf16=script_args.bf16,\n","    evaluation_strategy=\"steps\",\n","    max_grad_norm=script_args.max_grad_norm,\n","    max_steps=script_args.max_steps,\n","    warmup_ratio=script_args.warmup_ratio,\n","    group_by_length=script_args.group_by_length,\n","    lr_scheduler_type=script_args.lr_scheduler_type,\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["model, peft_config, tokenizer = create_and_prepare_model(script_args)\n","model.config.use_cache = False"]},{"cell_type":"code","execution_count":11,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Found cached dataset generator (/home/lin3060/.cache/huggingface/datasets/generator/default-d3a31d549d8b0b2f/0.0.0)\n","Found cached dataset generator (/home/lin3060/.cache/huggingface/datasets/generator/default-a3d63d20ba0037b9/0.0.0)\n"]}],"source":["train_gen = Dataset.from_generator(gen_batches_train)\n","\n","val_gen = Dataset.from_generator(gen_batches_val)"]},{"cell_type":"code","execution_count":12,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Dataset({\n","    features: ['text'],\n","    num_rows: 9000\n","})\n","Dataset({\n","    features: ['text'],\n","    num_rows: 1000\n","})\n"]}],"source":["print(train_gen)\n","\n","print(val_gen)"]},{"cell_type":"code","execution_count":13,"metadata":{"execution":{"iopub.execute_input":"2023-11-22T19:59:59.403244Z","iopub.status.busy":"2023-11-22T19:59:59.402960Z","iopub.status.idle":"2023-11-22T19:59:59.412016Z","shell.execute_reply":"2023-11-22T19:59:59.411122Z","shell.execute_reply.started":"2023-11-22T19:59:59.403221Z"},"trusted":true},"outputs":[{"data":{"text/plain":["MistralForCausalLM(\n","  (model): MistralModel(\n","    (embed_tokens): Embedding(32000, 4096)\n","    (layers): ModuleList(\n","      (0-31): 32 x MistralDecoderLayer(\n","        (self_attn): MistralAttention(\n","          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)\n","          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)\n","          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)\n","          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)\n","          (rotary_emb): MistralRotaryEmbedding()\n","        )\n","        (mlp): MistralMLP(\n","          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)\n","          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)\n","          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)\n","          (act_fn): SiLUActivation()\n","        )\n","        (input_layernorm): MistralRMSNorm()\n","        (post_attention_layernorm): MistralRMSNorm()\n","      )\n","    )\n","    (norm): MistralRMSNorm()\n","  )\n","  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",")"]},"execution_count":13,"metadata":{},"output_type":"execute_result"}],"source":["model"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\n","# Fix weird overflow issue with fp16 training\n","tokenizer.padding_side = \"right\"\n","\n","trainer = SFTTrainer(\n","    model=model,\n","    train_dataset=train_gen,\n","    eval_dataset=val_gen,\n","    peft_config=peft_config,\n","    dataset_text_field=\"text\",\n","    max_seq_length=script_args.max_seq_length,\n","    tokenizer=tokenizer,\n","    args=training_arguments,\n","    packing=script_args.packing,\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2023-11-22T20:00:21.944998Z","iopub.status.busy":"2023-11-22T20:00:21.944602Z","iopub.status.idle":"2023-11-22T20:03:56.063718Z","shell.execute_reply":"2023-11-22T20:03:56.062228Z","shell.execute_reply.started":"2023-11-22T20:00:21.944962Z"},"trusted":true},"outputs":[],"source":["trainer.train()"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.status.busy":"2023-11-22T20:03:56.065298Z","iopub.status.idle":"2023-11-22T20:03:56.066145Z","shell.execute_reply":"2023-11-22T20:03:56.065900Z","shell.execute_reply.started":"2023-11-22T20:03:56.065875Z"},"trusted":true},"outputs":[],"source":["if script_args.merge_and_push:\n","    output_dir = os.path.join(script_args.output_dir, \"final_checkpoints\")\n","    trainer.model.save_pretrained(output_dir)\n","\n","    # Free memory for merging weights\n","    del model\n","    torch.cuda.empty_cache()\n","\n","    model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map=\"auto\", torch_dtype=torch.bfloat16)\n","    model = model.merge_and_unload()\n","\n","    output_merged_dir = os.path.join(script_args.output_dir, \"Final_Model_Checkpoint\")\n","    model.save_pretrained(output_merged_dir, safe_serialization=True)"]},{"cell_type":"markdown","metadata":{},"source":["## Inference"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["from transformers import AutoModelForCausalLM, AutoTokenizer\n","import torch\n","\n","# Load the fine-tuned model and tokenizer\n","model_path = \"./results_packing/Final_Model_Checkpoint\"  # Update this path to your model's location\n","tokenizer = AutoTokenizer.from_pretrained(model_path)\n","model = AutoModelForCausalLM.from_pretrained(model_path)\n","\n","# Function to generate text based on a prompt\n","def generate_text(prompt, max_length=50):\n","    # Encode the input prompt\n","    input_ids = tokenizer.encode(prompt, return_tensors='pt')\n","\n","    # Generate a response\n","    output = model.generate(input_ids, max_length=max_length, num_return_sequences=1)\n","\n","    # Decode and return the generated text\n","    return tokenizer.decode(output[0], skip_special_tokens=True)\n","\n","# Example usage\n","prompt = \"Your input prompt goes here\"  # Replace with your input prompt\n","generated_text = generate_text(prompt)\n","print(generated_text)\n"]}],"metadata":{"kaggle":{"accelerator":"gpu","dataSources":[],"dockerImageVersionId":30588,"isGpuEnabled":true,"isInternetEnabled":true,"language":"python","sourceType":"notebook"},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.13"}},"nbformat":4,"nbformat_minor":4}
